{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/finetuning/gradient/gradient_text2sql.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine Tuning for Text-to-SQL With Gradient and LlamaIndex\n",
    "\n",
    "In this notebook we show you how to fine-tune llama2-7b on the [sql-create-context](https://huggingface.co/datasets/b-mc2/sql-create-context) dataset to be better at Text-to-SQL.\n",
    "\n",
    "We do this by using [gradient.ai](https://gradient.ai)\n",
    "\n",
    "**NOTE**: This is an alternative to our repo/guide on fine-tuning llama2-7b with Modal: https://github.com/run-llama/modal_finetune_sql"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install llama-index-llms-gradient\n",
    "%pip install llama-index-finetuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install llama-index gradientai -q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from llama_index.llms.gradient import GradientBaseModelLLM\n",
    "from llama_index.finetuning import GradientFinetuneEngine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"GRADIENT_ACCESS_TOKEN\"] = os.getenv(\"GRADIENT_API_KEY\", \"\")  # or paste your access token here\n",
    "os.environ[\"GRADIENT_WORKSPACE_ID\"] = \"\"  # paste your Gradient workspace ID here"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare Data\n",
    "\n",
    "We load sql-create-context from Hugging Face datasets. The dataset is a mix of WikiSQL and Spider, and is organized in the format of input query, context, and ground-truth SQL statement. The context is a CREATE TABLE statement."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dialect = \"sqlite\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load Data, Save to a JSONL File"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from pathlib import Path\n",
    "import json\n",
    "\n",
    "\n",
    "def load_jsonl(data_dir):\n",
    "    data_path = Path(data_dir).as_posix()\n",
    "    data = load_dataset(\"json\", data_files=data_path)\n",
    "    return data\n",
    "\n",
    "\n",
    "def save_jsonl(data_dicts, out_path):\n",
    "    with open(out_path, \"w\") as fp:\n",
    "        for data_dict in data_dicts:\n",
    "            fp.write(json.dumps(data_dict) + \"\\n\")\n",
    "\n",
    "\n",
    "def load_data_sql(data_dir: str = \"data_sql\"):\n",
    "    dataset = load_dataset(\"b-mc2/sql-create-context\")\n",
    "\n",
    "    dataset_splits = {\"train\": dataset[\"train\"]}\n",
    "    out_path = Path(data_dir)\n",
    "\n",
    "    out_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    # write each split as JSONL with input/context/output keys\n",
    "    for key, ds in dataset_splits.items():\n",
    "        with open(out_path, \"w\") as f:\n",
    "            for item in ds:\n",
    "                newitem = {\n",
    "                    \"input\": item[\"question\"],\n",
    "                    \"context\": item[\"context\"],\n",
    "                    \"output\": item[\"answer\"],\n",
    "                }\n",
    "                f.write(json.dumps(newitem) + \"\\n\")"
   ]
  },
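  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check (a hypothetical aside, not part of the original flow), `save_jsonl` writes one JSON object per line, which we can read back with plain `json`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hypothetical round-trip check for the JSONL helpers above\n",
    "toy_records = [{\"input\": \"q\", \"context\": \"c\", \"output\": \"o\"}]\n",
    "save_jsonl(toy_records, \"toy.jsonl\")\n",
    "with open(\"toy.jsonl\") as fp:\n",
    "    loaded = [json.loads(line) for line in fp]\n",
    "assert loaded == toy_records"
   ]
  },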
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dump data to data_sql\n",
    "load_data_sql(data_dir=\"data_sql\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Split into Training/Validation Splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import ceil\n",
    "\n",
    "\n",
    "def get_train_val_splits(\n",
    "    data_dir: str = \"data_sql\",\n",
    "    val_ratio: float = 0.1,\n",
    "    seed: int = 42,\n",
    "    shuffle: bool = True,\n",
    "):\n",
    "    data = load_jsonl(data_dir)\n",
    "    num_samples = len(data[\"train\"])\n",
    "    val_set_size = ceil(val_ratio * num_samples)\n",
    "\n",
    "    train_val = data[\"train\"].train_test_split(\n",
    "        test_size=val_set_size, shuffle=shuffle, seed=seed\n",
    "    )\n",
    "    return train_val[\"train\"].shuffle(), train_val[\"test\"].shuffle()"
   ]
  },
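  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate the split sizing (a toy example with assumed numbers): with `val_ratio=0.1`, the validation set gets `ceil(0.1 * num_samples)` examples and the rest go to training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# toy illustration of the split sizing used above (numbers are hypothetical)\n",
    "from math import ceil\n",
    "\n",
    "num_samples = 1005  # pretend dataset size\n",
    "val_set_size = ceil(0.1 * num_samples)  # 101\n",
    "train_set_size = num_samples - val_set_size  # 904\n",
    "print(train_set_size, val_set_size)"
   ]
  },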
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.008841991424560547,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": 28,
       "postfix": null,
       "prefix": "Downloading data files",
       "rate": null,
       "total": 1,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "943f0cac2df34115b3f2a94615d806b8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004725933074951172,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": 28,
       "postfix": null,
       "prefix": "Extracting data files",
       "rate": null,
       "total": 1,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "acefb752a3b947768c6c0a9281352880",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004148006439208984,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": 28,
       "postfix": null,
       "prefix": "Generating train split",
       "rate": null,
       "total": 0,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "70e5c14e9b0e4bcd9f35014a9b0e1ff6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split: 0 examples [00:00, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "raw_train_data, raw_val_data = get_train_val_splits(data_dir=\"data_sql\")\n",
    "save_jsonl(raw_train_data, \"train_data_raw.jsonl\")\n",
    "save_jsonl(raw_val_data, \"val_data_raw.jsonl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input': 'If the record is 5-5, what is the game maximum?',\n",
       " 'context': 'CREATE TABLE table_23285805_4 (game INTEGER, record VARCHAR)',\n",
       " 'output': 'SELECT MAX(game) FROM table_23285805_4 WHERE record = \"5-5\"'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_train_data[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Map Training/Validation Dataset Dictionaries to Prompts\n",
    "\n",
    "Here we define functions to map the dataset dictionaries to a prompt format that we can then feed to gradient.ai's fine-tuning endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Format is similar to the nous-hermes LLMs\n",
    "\n",
    "text_to_sql_tmpl_str = \"\"\"\\\n",
    "<s>### Instruction:\\n{system_message}{user_message}\\n\\n### Response:\\n{response}</s>\"\"\"\n",
    "\n",
    "text_to_sql_inference_tmpl_str = \"\"\"\\\n",
    "<s>### Instruction:\\n{system_message}{user_message}\\n\\n### Response:\\n\"\"\"\n",
    "\n",
    "### Alternative Format\n",
    "### Recommended by gradient.ai docs, but empirically we found worse results here\n",
    "\n",
    "# text_to_sql_tmpl_str = \"\"\"\\\n",
    "# <s>[INST] SYS\\n{system_message}\\n<</SYS>>\\n\\n{user_message} [/INST] {response} </s>\"\"\"\n",
    "\n",
    "# text_to_sql_inference_tmpl_str = \"\"\"\\\n",
    "# <s>[INST] SYS\\n{system_message}\\n<</SYS>>\\n\\n{user_message} [/INST] \"\"\"\n",
    "\n",
    "\n",
    "def _generate_prompt_sql(input, context, dialect=\"sqlite\", output=\"\"):\n",
    "    system_message = f\"\"\"You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. \n",
    "\n",
    "You must output the SQL query that answers the question.\n",
    "    \n",
    "    \"\"\"\n",
    "    user_message = f\"\"\"### Dialect:\n",
    "{dialect}\n",
    "\n",
    "### Input:\n",
    "{input}\n",
    "\n",
    "### Context:\n",
    "{context}\n",
    "\n",
    "### Response:\n",
    "\"\"\"\n",
    "    if output:\n",
    "        return text_to_sql_tmpl_str.format(\n",
    "            system_message=system_message,\n",
    "            user_message=user_message,\n",
    "            response=output,\n",
    "        )\n",
    "    else:\n",
    "        return text_to_sql_inference_tmpl_str.format(\n",
    "            system_message=system_message, user_message=user_message\n",
    "        )\n",
    "\n",
    "\n",
    "def generate_prompt(data_point):\n",
    "    full_prompt = _generate_prompt_sql(\n",
    "        data_point[\"input\"],\n",
    "        data_point[\"context\"],\n",
    "        dialect=\"sqlite\",\n",
    "        output=data_point[\"output\"],\n",
    "    )\n",
    "    return {\"inputs\": full_prompt}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = [\n",
    "    {\"inputs\": d[\"inputs\"]} for d in raw_train_data.map(generate_prompt)\n",
    "]\n",
    "save_jsonl(train_data, \"train_data.jsonl\")\n",
    "val_data = [{\"inputs\": d[\"inputs\"]} for d in raw_val_data.map(generate_prompt)]\n",
    "save_jsonl(val_data, \"val_data.jsonl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<s>### Instruction:\n",
      "You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. \n",
      "\n",
      "You must output the SQL query that answers the question.\n",
      "    \n",
      "    ### Dialect:\n",
      "sqlite\n",
      "\n",
      "### Input:\n",
      "Who had the fastest lap in bowmanville, ontario?\n",
      "\n",
      "### Context:\n",
      "CREATE TABLE table_30134667_2 (fastest_lap VARCHAR, location VARCHAR)\n",
      "\n",
      "### Response:\n",
      "\n",
      "\n",
      "### Response:\n",
      "SELECT fastest_lap FROM table_30134667_2 WHERE location = \"Bowmanville, Ontario\"</s>\n"
     ]
    }
   ],
   "source": [
    "print(train_data[0][\"inputs\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run Fine-tuning with gradient.ai\n",
    "\n",
    "Here we call Gradient's fine-tuning endpoint with the `GradientFinetuneEngine`. \n",
    "\n",
    "We limit the number of training steps for demonstration purposes; feel free to adjust the parameters as you wish.\n",
    "\n",
    "At the end we fetch our fine-tuned LLM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# base_model_slug = \"nous-hermes2\"\n",
    "base_model_slug = \"llama2-7b-chat\"\n",
    "base_llm = GradientBaseModelLLM(\n",
    "    base_model_slug=base_model_slug, max_tokens=300\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: max_steps is capped here just for demonstration purposes\n",
    "# NOTE: can only specify one of base_model_slug or model_adapter_id\n",
    "finetune_engine = GradientFinetuneEngine(\n",
    "    base_model_slug=base_model_slug,\n",
    "    # model_adapter_id='805c6fd6-daa8-4fc8-a509-bebb2f2c1024_model_adapter',\n",
    "    name=\"text_to_sql\",\n",
    "    data_path=\"train_data.jsonl\",\n",
    "    verbose=True,\n",
    "    max_steps=200,\n",
    "    batch_size=4,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'805c6fd6-daa8-4fc8-a509-bebb2f2c1024_model_adapter'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "finetune_engine.model_adapter_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 1\n",
    "for i in range(epochs):\n",
    "    print(f\"** EPOCH {i} **\")\n",
    "    finetune_engine.finetune()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ft_llm = finetune_engine.get_finetuned_model(max_tokens=300)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation\n",
    "\n",
    "The evaluation has two parts:\n",
    "1. We evaluate on some sample datapoints in the validation dataset.\n",
    "2. We evaluate on a new toy SQL dataset, and plug the fine-tuned LLM into our `NLSQLTableQueryEngine` to run a full text-to-SQL workflow.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Part 1: Evaluation on Validation Dataset Datapoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_text2sql_completion(llm, raw_datapoint):\n",
    "    text2sql_tmpl_str = _generate_prompt_sql(\n",
    "        raw_datapoint[\"input\"],\n",
    "        raw_datapoint[\"context\"],\n",
    "        dialect=\"sqlite\",\n",
    "        output=None,\n",
    "    )\n",
    "\n",
    "    response = llm.complete(text2sql_tmpl_str)\n",
    "    return str(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input': ' how many\\xa0reverse\\xa0with\\xa0series\\xa0being iii series',\n",
       " 'context': 'CREATE TABLE table_12284476_8 (reverse VARCHAR, series VARCHAR)',\n",
       " 'output': 'SELECT COUNT(reverse) FROM table_12284476_8 WHERE series = \"III series\"'}"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "test_datapoint = raw_val_data[2]\n",
    "display(test_datapoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run base llama2-7b-chat model\n",
    "get_text2sql_completion(base_llm, test_datapoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'SELECT MIN(year) FROM table_name_35 WHERE venue = \"barcelona, spain\"'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# run fine-tuned llama2-7b-chat model\n",
    "get_text2sql_completion(ft_llm, test_datapoint)"
   ]
  },
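  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One simple way to score such completions (a rough sketch; `normalize_sql` is a hypothetical helper, not part of LlamaIndex) is a whitespace- and case-normalized exact match against the gold SQL:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hypothetical exact-match scoring sketch\n",
    "def normalize_sql(sql: str) -> str:\n",
    "    # crude normalization: strip trailing semicolon, collapse whitespace, lowercase\n",
    "    return \" \".join(sql.strip().rstrip(\";\").split()).lower()\n",
    "\n",
    "\n",
    "pred = get_text2sql_completion(ft_llm, test_datapoint)\n",
    "gold = test_datapoint[\"output\"]\n",
    "print(normalize_sql(pred) == normalize_sql(gold))"
   ]
  },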
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Part 2: Evaluation on a Toy Dataset\n",
    "\n",
    "Here we create a toy table of cities and their populations."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports for building a toy SQL database\n",
    "from sqlalchemy import (\n",
    "    create_engine,\n",
    "    MetaData,\n",
    "    Table,\n",
    "    Column,\n",
    "    String,\n",
    "    Integer,\n",
    "    select,\n",
    "    column,\n",
    ")\n",
    "from llama_index.core import SQLDatabase"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "engine = create_engine(\"sqlite:///:memory:\")\n",
    "metadata_obj = MetaData()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create city SQL table\n",
    "table_name = \"city_stats\"\n",
    "city_stats_table = Table(\n",
    "    table_name,\n",
    "    metadata_obj,\n",
    "    Column(\"city_name\", String(16), primary_key=True),\n",
    "    Column(\"population\", Integer),\n",
    "    Column(\"country\", String(16), nullable=False),\n",
    ")\n",
    "metadata_obj.create_all(engine)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "CREATE TABLE city_stats (\n",
      "\tcity_name VARCHAR(16) NOT NULL, \n",
      "\tpopulation INTEGER, \n",
      "\tcountry VARCHAR(16) NOT NULL, \n",
      "\tPRIMARY KEY (city_name)\n",
      ")\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# This context is used later on\n",
    "from sqlalchemy.schema import CreateTable\n",
    "\n",
    "table_create_stmt = str(CreateTable(city_stats_table))\n",
    "print(table_create_stmt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql_database = SQLDatabase(engine, include_tables=[\"city_stats\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Populate with Test Datapoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# insert sample rows\n",
    "from sqlalchemy import insert\n",
    "\n",
    "rows = [\n",
    "    {\"city_name\": \"Toronto\", \"population\": 2930000, \"country\": \"Canada\"},\n",
    "    {\"city_name\": \"Tokyo\", \"population\": 13960000, \"country\": \"Japan\"},\n",
    "    {\n",
    "        \"city_name\": \"Chicago\",\n",
    "        \"population\": 2679000,\n",
    "        \"country\": \"United States\",\n",
    "    },\n",
    "    {\"city_name\": \"Seoul\", \"population\": 9776000, \"country\": \"South Korea\"},\n",
    "]\n",
    "for row in rows:\n",
    "    stmt = insert(city_stats_table).values(**row)\n",
    "    with engine.connect() as connection:\n",
    "        cursor = connection.execute(stmt)\n",
    "        connection.commit()"
   ]
  },
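  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An optional quick check that all four rows landed in the table:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# verify the inserts above\n",
    "from sqlalchemy import text\n",
    "\n",
    "with engine.connect() as connection:\n",
    "    num_rows = connection.execute(\n",
    "        text(\"SELECT COUNT(*) FROM city_stats\")\n",
    "    ).scalar()\n",
    "print(num_rows)  # expect 4"
   ]
  },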
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Get Text2SQL Query Engine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.query_engine import NLSQLTableQueryEngine\n",
    "from llama_index.core import PromptTemplate\n",
    "\n",
    "\n",
    "def get_text2sql_query_engine(llm, table_context, sql_database):\n",
    "    # we essentially swap existing template variables for new template variables\n",
    "    # to put into our `NLSQLTableQueryEngine`\n",
    "    text2sql_tmpl_str = _generate_prompt_sql(\n",
    "        \"{query_str}\", \"{schema}\", dialect=\"{dialect}\", output=\"\"\n",
    "    )\n",
    "    sql_prompt = PromptTemplate(text2sql_tmpl_str)\n",
    "    # Here we explicitly set the table context to be the CREATE TABLE string,\n",
    "    # so we set `tables` to empty and hard-code the `context_str` prefix\n",
    "\n",
    "    query_engine = NLSQLTableQueryEngine(\n",
    "        sql_database,\n",
    "        tables=[],\n",
    "        context_str_prefix=table_context,\n",
    "        text_to_sql_prompt=sql_prompt,\n",
    "        llm=llm,\n",
    "        synthesize_response=False,\n",
    "    )\n",
    "    return query_engine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# query = \"Which cities have populations less than 10 million people?\"\n",
    "query = \"What is the population of Tokyo? (make sure cities/countries are capitalized)\"\n",
    "# query = \"What is the average population and total population of the cities?\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Results with base llama2 model\n",
    "The base llama2 model appends extra explanatory text to the SQL statement, which breaks our parser (it also makes minor capitalization mistakes)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_query_engine = get_text2sql_query_engine(\n",
    "    base_llm, table_create_stmt, sql_database\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_response = base_query_engine.query(query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error: You can only execute one statement at a time.\n"
     ]
    }
   ],
   "source": [
    "print(str(base_response))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"SELECT population FROM city_stats WHERE country = 'JAPAN';\\n\\nThis will return the population of Tokyo, which is the only city in the table with a population value.\""
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "base_response.metadata[\"sql_query\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Results with fine-tuned model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ft_query_engine = get_text2sql_query_engine(\n",
    "    ft_llm, table_create_stmt, sql_database\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ft_response = ft_query_engine.query(query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(13960000,)]\n"
     ]
    }
   ],
   "source": [
    "print(str(ft_response))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'SELECT population FROM city_stats WHERE country = \"Japan\" AND city_name = \"Tokyo\"'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ft_response.metadata[\"sql_query\"]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama_index_v2",
   "language": "python",
   "name": "llama_index_v2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
